--- title: Categorical DQN keywords: fastai sidebar: home_sidebar summary: "An implementation of a DQN that uses distributions to represent Q from the paper A Distributional Perspective on Reinforcement Learning" description: "An implementation of a DQN that uses distributions to represent Q from the paper A Distributional Perspective on Reinforcement Learning" nb_path: "nbs/10e_agents.dqn.categorical.ipynb" ---

The Categorical DQN can be summarized as:

Instead of each action output being a single Q value, each action is represented by a distribution of size `N`.

We start off with the idea of atoms and supports. A support acts as a mask over the output action distributions. This is illustrated by the equations and the corresponding functions below.

We start with the equation...

$$ {\large Z_{\theta}(z,a) = z_i \quad w.p. \: p_i(x,a):= \frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} } $$

... which shows that the output of our neural net needs to be squashed into a proper probability distribution. It also introduces $z_i$, a support, which we will define very soon. Below is the implementation of the right-hand side of the equation for $p_i(x,a)$.

An important note is that $\frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} $ is just:

{% raw %}
Softmax
torch.nn.modules.activation.Softmax
{% endraw %}

We pretend that the output of the neural net is of shape (batch_sz, n_atoms, n_actions). In this instance, there is only one action. This implies that $Z_{\theta}$ is just $z_0$.

{% raw %}
out=Softmax(dim=1)(torch.randn(1,51,1))[0] # Action 0
plt.plot(out.numpy())
{% endraw %}

The next function describes how probabilities are calculated from the neural net output. The equation defines a $z_i$ as: $$ \{z_i = V_{min} + i\Delta z : 0 \leq i < N \}, \: \Delta z := \frac{V_{max} - V_{min}}{N - 1} $$

Where $V_{max}$, $V_{min}$, and $N$ are constants that we define. Note that $N$ is the number of atoms. So what does a $z_i$ look like? We will define this in code below...

{% raw %}

create_support[source]

create_support(v_min=-10, v_max=10, n_atoms=51)

Creates the support and returns the z_delta that was used.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
import matplotlib.pyplot as plt

support_dist,z_delta=create_support()
print('z_delta: ',z_delta)
plt.plot(support_dist.numpy())
z_delta:  0.4
{% endraw %}
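For reference, here is a minimal sketch of what `create_support` presumably does. This is my own reconstruction, not the library source, but it reproduces the printed `z_delta` of 0.4 for the defaults:

```python
import torch

def create_support(v_min=-10, v_max=10, n_atoms=51):
    "Sketch: evenly spaced atoms z_i = v_min + i*Δz, plus the Δz used."
    z_delta = (v_max - v_min) / (n_atoms - 1)  # Δz := (V_max - V_min) / (N - 1)
    return torch.linspace(v_min, v_max, n_atoms), z_delta
```

With the defaults this gives 51 atoms evenly spaced from -10 to 10 with a step of 0.4.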

This is a single $z_i$ in $Z_{\theta}$. The number of $z_i$s is equal to the number of actions that the DQN is operating with. {% include note.html content='Josiah: Is this always the case? Could there be only $z_0$ and multiple actions?' %} Ok! Hopefully this wasn't too bad to get through. We basically normalized the neural net output to be a proper probability distribution, and created/initialized a bunch of increasing arrays that we call discrete supports, i.e. the output of create_support.

Now for the fun part! We have this giant update equation:

$$ {\large (\Phi\hat{\mathcal{T}}Z_{\theta}(x,a))_i = \sum_{j=0}^{N-1} \left[ 1 - \frac{ | \mathcal{T}z_j |_{V_{min}}^{V_{max}} - z_i }{ \Delta z } \right]_0^1 p_j(x^{\prime},\pi(x^{\prime})) } $$

Good god... and we also have

$$ \hat{\mathcal{T}}z_j := r + \gamma z_j $$

where, to quote the paper:

"for each atom $z_j$, [and] then distribute its probability $ p_j(x^{\prime},\pi(x^{\prime})) $ to the immediate neighbors of $ \hat{\mathcal{T}}z_j $"

I highly recommend reading pg. 6 in the paper for a fuller explanation. I originally wondered what the difference was between $\pi$ and plain $\theta$: the main difference is that $\pi$ is a greedy action selection, i.e. we run argmax to get the action.

This was a lot! Luckily the authors provide a reformulation in algorithmic form:

{% raw %}
def categorical_update(v_min,v_max,n_atoms,support,delta_z,model,reward,gamma,action,next_state):
    p=Softmax(dim=-1)(model(next_state))   # (n_actions,n_atoms) probabilities per action
    t_q=(support*p).sum(dim=-1)            # Q(x',a) = sum_i z_i p_i(x',a)
    a_star=torch.argmax(t_q)
    
    m=torch.zeros((n_atoms,)) # m_i = 0 where i in 0,...,N-1
    
    for j in range(n_atoms):
        # Compute the projection of $ \hat{\mathcal{T}}z_j $ onto the support
        target_z=torch.clamp(reward+gamma*support[j],v_min,v_max)
        b_j=(target_z-v_min)/delta_z # b_j in [0,N-1]
        l=torch.floor(b_j).long()
        u=torch.ceil(b_j).long()
        # Distribute probability of $ \hat{\mathcal{T}}z_j $
        m[l]=m[l]+p[a_star,j]*(u-b_j)
        m[u]=m[u]+p[a_star,j]*(b_j-l)
    return # Some cross entropy loss between m and the predicted distribution
{% endraw %}

There are a few problems with the function above. It was a (fairly) literal conversion of Algorithm 1 in the paper to Python:

  • The current setup doesn't handle batches
  • Some of the variable names are vague
  • It does not handle terminal states

Let's rename these! We will instead have:
$$ m\_i \rightarrow projection\\ a\_star \rightarrow next\_action\\ b\_j \rightarrow support\_value\\ l \rightarrow support\_left\\ u \rightarrow support\_right\\ $$
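Using these names, a batched, vectorized sketch of the projection step might look like the following. The function name and shapes are my own assumptions: `p_a` is the next-state distribution already gathered for the greedy action, with shape `(bs, n_atoms)`, and `rewards`/`dones` have shape `(bs, 1)`:

```python
import torch

def project_distribution(p_a, rewards, dones, support, delta_z,
                         v_min=-10., v_max=10., gamma=0.99):
    "Hypothetical batched projection of T z_j onto the fixed support."
    bs, n_atoms = p_a.shape
    # T z_j = r + γ z_j, clamped to [v_min, v_max]; terminal states collapse to r
    target_z = (rewards + gamma * support.view(1, -1) * (~dones)).clamp(v_min, v_max)
    support_value = (target_z - v_min) / delta_z          # b_j in [0, N-1]
    support_left  = support_value.floor().long()
    support_right = support_value.ceil().long()
    # If b_j lands exactly on an atom, left==right would drop its mass; nudge one side
    support_left[(support_right > 0) & (support_left == support_right)] -= 1
    support_right[(support_left < n_atoms - 1) & (support_left == support_right)] += 1
    projection = torch.zeros_like(p_a)
    # Distribute each p_j between its two neighbouring atoms, flattening the batch
    offset = torch.arange(bs).view(-1, 1) * n_atoms
    projection.view(-1).index_add_(0, (support_left + offset).view(-1),
                                   (p_a * (support_right.float() - support_value)).view(-1))
    projection.view(-1).index_add_(0, (support_right + offset).view(-1),
                                   (p_a * (support_value - support_left.float())).view(-1))
    return projection
```

Because the left/right weights sum to 1 for every atom, each row of the output remains a valid probability distribution, and terminal rows collapse all mass onto the atom nearest the reward.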

So let's revise the problem and pretend that we have a 2-action model with a batch size of 8, where the last element is a terminal state, and where left actions are -1 while right actions are 1.

{% raw %}
from torch.distributions.normal import Normal
{% endraw %}

So for a single action we would have a distribution like this...

{% raw %}
plt.plot(Normal(0,1).sample((51,)).numpy())
{% endraw %}

So since our model has 2 actions that it can pick, we create some distributions for them...

{% raw %}
dist_left=torch.vstack([Normal(0.5,1).sample((1,51)),Normal(0.5,0.1).sample((1,51))]).unsqueeze(0)
dist_right=torch.vstack([Normal(0.5,0.1).sample((1,51)),Normal(0.5,1).sample((1,51))]).unsqueeze(0)
(dist_left.shape,dist_right.shape)
(torch.Size([1, 2, 51]), torch.Size([1, 2, 51]))
{% endraw %}

...where the $[1, 2, 51]$ is $[batch, action, n\_atoms]$

{% raw %}
model_out=torch.vstack([copy([dist_left,dist_right][i%2==0]) for i in range(1,9)]).to(device=default_device())
(model_out.shape)
torch.Size([8, 2, 51])
{% endraw %} {% raw %}
summed_model_out=model_out.sum(dim=2);summed_model_out=Softmax(dim=1)(summed_model_out).to(device=default_device())
(summed_model_out.shape,summed_model_out)
(torch.Size([8, 2]),
 tensor([[0.8571, 0.1429],
         [0.0046, 0.9954],
         [0.8571, 0.1429],
         [0.0046, 0.9954],
         [0.8571, 0.1429],
         [0.0046, 0.9954],
         [0.8571, 0.1429],
         [0.0046, 0.9954]], device='cuda:0'))
{% endraw %}

So when we sum/normalize the distributions per batch, per action, we get an output that looks like your typical DQN output...

We can also treat this like a regular DQN and do an argmax to get actions like usual...

{% raw %}
actions=torch.argmax(summed_model_out,dim=1).reshape(-1,1).to(device=default_device());actions
tensor([[0],
        [1],
        [0],
        [1],
        [0],
        [1],
        [0],
        [1]], device='cuda:0')
{% endraw %} {% raw %}
rewards=actions;rewards
tensor([[0],
        [1],
        [0],
        [1],
        [0],
        [1],
        [0],
        [1]], device='cuda:0')
{% endraw %} {% raw %}
dones=Tensor().new_zeros((8,1)).bool().to(device=default_device());dones[-1][0]=1;dones
tensor([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True]], device='cuda:0')
{% endraw %}

So let's decompose the categorical_update above into something easier to read. First we will note the authors' original algorithm:

{% include image.html width="500" height="500" max-width="500" file="/fastrl/docs/images/10e_agents.dqn.categorical_algorithm1.png" %}

We can break this into 3 different functions:

- getting the Q<br>
- calculating the update<br>
- calculating the loss

We will start with $Q(x_{t+1},a):=\sum_i z_i p_i(x_{t+1},a)$

{% raw %}

class CategoricalDQN[source]

CategoricalDQN(state_sz:int, action_sz:int, n_atoms:int=51, hidden=512, v_min=-10, v_max=10) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

{% endraw %} {% raw %}
{% endraw %}

The CategoricalDQN.q function gets us 90% of the way to the equation above. However, you will notice that the equation is for a specific action. We will handle this in the actual update function.
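As a sketch of what the q computation presumably does (the function name below is my own, not the library's method), we softmax over the atom dimension and take the support-weighted sum:

```python
import torch

def q_values(model_out, support):
    "Q(x,a) = Σ_i z_i p_i(x,a) for model_out of shape (bs, n_actions, n_atoms)."
    p = torch.softmax(model_out, dim=2)             # normalise each action's atoms
    return (p * support.view(1, 1, -1)).sum(dim=2)  # -> (bs, n_actions)
```

A uniform distribution over a symmetric support gives a Q of 0 for every action, which is why fresh untrained outputs hover near zero.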

{% raw %}
dqn=CategoricalDQN(4,2).to(device=default_device())
dqn(torch.randn(8,4).to(device=default_device())).shape
torch.Size([8, 2, 51])
{% endraw %} {% raw %}
dqn.q(torch.randn(8,4).to(device=default_device()))
tensor([[ 0.1586, -0.0736],
        [-0.0089, -0.1592],
        [ 0.1195, -0.2730],
        [ 0.0953, -0.0631],
        [ 0.0295,  0.0125],
        [ 0.1460, -0.1330],
        [ 0.0863, -0.0585],
        [ 0.1780, -0.1240]], device='cuda:0', grad_fn=<SumBackward1>)
{% endraw %} {% raw %}
dqn.policy(torch.randn(8,4).to(device=default_device()))
tensor([[ 1.5105e-03, -1.1409e-02],
        [ 1.7312e-03, -2.0338e-03],
        [ 1.2366e-03, -7.5805e-04],
        [-2.4870e-04, -1.3664e-03],
        [ 5.1208e-03, -4.6985e-03],
        [ 1.4227e-03,  8.4235e-06],
        [ 8.5255e-05, -2.8415e-03],
        [-1.0086e-03, -1.2995e-03]], device='cuda:0', grad_fn=<MeanBackward1>)
{% endraw %} {% raw %}

distribute[source]

distribute(projection, left, right, support_value, p_a, atom, done)

Does: m_l <- m_l + pj(x{t+1},a*)(u - b_j) operation for non-final states.

{% endraw %} {% raw %}

final_distribute[source]

final_distribute(projection, left, right, support_value, p_a, atom, done)

Does: m_l <- m_l + pj(x{t+1},a*)(u - b_j) operation for final states.

{% endraw %} {% raw %}
{% endraw %} {% raw %}

categorical_update[source]

categorical_update(support, delta_z, q, p, actions, rewards, dones, v_min=-10, v_max=10, n_atoms=51, gamma=0.99, passes=None)

{% endraw %} {% raw %}
{% endraw %} {% raw %}

show_q_distribution[source]

show_q_distribution(cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
output=categorical_update(dqn.supports,dqn.z_delta,summed_model_out,
                          Softmax(dim=2)(model_out),actions,rewards,dones,passes=None)
show_q_distribution(output)
{% endraw %} {% raw %}
q=dqn.q(torch.randn(8,4).to(device=default_device()))
p=dqn.p(torch.randn(8,4).to(device=default_device()))

output=categorical_update(dqn.supports,dqn.z_delta,q,p,actions,rewards,dones)
show_q_distribution(output,title='Real Model Update Distributions')
{% endraw %} {% raw %}

PartialCrossEntropy[source]

PartialCrossEntropy(p, q)

{% endraw %} {% raw %}
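The loss is presumably the cross-entropy between the projected target distribution and the model's predicted log-probabilities, matching the `(-pred*target).sum(dim=1).mean()` pattern used later in this page. A hedged sketch, with the name and argument order assumed:

```python
import torch

def partial_cross_entropy(log_p, m):
    "Cross-entropy -Σ_i m_i log p_i, averaged over the batch; both (bs, n_atoms)."
    return -(m * log_p).sum(dim=1).mean()
```

For a uniform prediction and uniform target over 51 atoms, this evaluates to log(51) ≈ 3.93, which is roughly the scale of the training losses shown below.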
{% endraw %} {% raw %}

class CategoricalDQNTrainer[source]

CategoricalDQNTrainer(n_batch=0, target_sync=300, discount=0.99, n_steps=1) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class CategoricalArgMaxFeed[source]

CategoricalArgMaxFeed() :: AgentCallback

Basic class handling tweaks of a callback loop by changing a obj in various events

{% endraw %} {% raw %}
{% endraw %} {% raw %}
dqn=CategoricalDQN(4,2)

agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[GymLoop('CartPole-v1',agent,steps_count=3,seed=0,
                           steps_delta=1),FirstLast])
dls=SourceDataBlock().dataloaders([source],n=100,bs=1,num_workers=0)

learn=Learner(dls,agent,loss_func=PartialCrossEntropy,
              cbs=[ExperienceReplayCallback(bs=32,max_sz=100000,warmup_sz=32),CategoricalDQNTrainer(target_sync=300)],
              metrics=[Reward,Epsilon,NEpisodes])
Could not do one pass in your dataloader, there is something wrong in it
{% endraw %} {% raw %}
slow=False
learn.fit(3 if not slow else 47,lr=0.0001,wd=0)

| epoch | train_loss | train_reward | train_epsilon | train_n_episodes | valid_loss | valid_reward | valid_epsilon | valid_n_episodes | time |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 3.103314 | 21.270588 | 0.959600 | 62 | | | | | 00:11 |
| 1 | 3.426107 | 20.250000 | 0.919200 | 178 | | | | | 00:14 |
| 2 | 3.432804 | 24.750000 | 0.878800 | 299 | | | | | 00:16 |
{% endraw %} {% raw %}
from IPython.display import HTML
import plotly.express as px
{% endraw %} {% raw %}

show_q[source]

show_q(cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
learn.cbs[-1].local_pred.shape
torch.Size([32, 51])
{% endraw %} {% raw %}
learn.cbs[-1].local_v.shape
torch.Size([32, 2, 51])
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_yb[0])
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_pred)
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_v[:,1,:])
{% endraw %} {% raw %}
show_q(learn.cbs[-1].local_v[:,0,:])
{% endraw %} {% raw %}
(-learn.cbs[-1].local_pred*learn.cbs[-1].local_yb[0]).sum(dim=1).mean()
TensorBatch(3.4515, device='cuda:0', grad_fn=<AliasBackward>)
{% endraw %} {% raw %}
show_q(-learn.cbs[-1].local_pred*learn.cbs[-1].local_yb[0])
{% endraw %} {% raw %}
from IPython.display import HTML
import plotly.express as px

agent=Agent(dqn,cbs=[CategoricalArgMaxFeed,DiscreteEpsilonRandomSelect(min_epsilon=0.0001,max_epsilon=0.0002,epsilon=0.0002)])
source=Src('CartPole-v1',agent,seed=0,steps_count=1,n_envs=1,steps_delta=1,mode='rgb_array',cbs=[GymSrc,FirstLast])


agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect(min_epsilon=0.0001,max_epsilon=0.0002,epsilon=0.0002)])
source=Source(cbs=[GymLoop('CartPole-v1',agent,steps_count=3,seed=0,
                           steps_delta=1),FirstLast])
dls=SourceDataBlock().dataloaders([source],n=100,bs=1,num_workers=0)



exp=[o for o,_ in zip(source,range(50))]

fig = px.imshow(torch.vstack([o['image'] for o in exp]).numpy(),animation_frame=0)
HTML(fig.to_html())
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipykernel_257/2032628216.py in <module>
      3 
      4 agent=Agent(dqn,cbs=[CategoricalArgMaxFeed,DiscreteEpsilonRandomSelect(min_epsilon=0.0001,max_epsilon=0.0002,epsilon=0.0002)])
----> 5 source=Src('CartPole-v1',agent,seed=0,steps_count=1,n_envs=1,steps_delta=1,mode='rgb_array',cbs=[GymSrc,FirstLast])
      6 
      7 exp=[o for o,_ in zip(source,range(50))]

NameError: name 'Src' is not defined
{% endraw %} {% raw %}

show_q_and_max_distribution[source]

show_q_and_max_distribution(cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
show_q_and_max_distribution(dqn.policy(torch.vstack([o['state'] for o in exp]).to(device=default_device())))
{% endraw %}

If you want to run this using multiple processes, the multiprocessing code looks like below. However, you will not be able to run this in a notebook; instead, add the code to a .py file and run it from there.

{% include warning.html content='There is a bug in data block that prevents this. Should be a simple fix.' %}